#!/usr/bin/env python3

import argparse
import os
import sys
from glob import glob

import numpy as np
import torch

sys.path.insert(1, os.path.join(sys.path[0], ('..')))

import colorization.config as config
from colorization.modules.cross_entropy_loss_2d import CrossEntropyLoss2d
from colorization.util.argparse import nice_help_formatter
from colorization.util.progress import display_progress


# Usage text handed to argparse in place of the auto-generated one. The
# continuation lines are indented by 20 columns so that the options line up
# once argparse prefixes the text with "usage: ".
USAGE = \
"""profile_loss [-h|--help]
                    --config CONFIG
                    --model-dir MODEL_DIR
                    --val-dir VAL_DIR
                    [--default-config DEFAULT_CONFIG]
                    [--checkpoint-step CHECKPOINT_STEP]
                    [--batch-limit BATCH_LIMIT]
                    [--verbose]"""


def positive_int(value):
    """Convert a command line value to an integer that is at least 1.

    Intended as an argparse ``type=`` callable for options where zero or a
    negative number makes no sense (a slice step, a batch count).

    Raises:
        ValueError: if ``value`` is not an integer literal; argparse turns
            this into its standard "invalid ... value" error.
        argparse.ArgumentTypeError: if the integer is smaller than 1.
    """
    number = int(value)

    if number < 1:
        raise argparse.ArgumentTypeError(
            "expected a positive integer, got {}".format(number))

    return number


def parse_args():
    """Define the command line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(formatter_class=nice_help_formatter(),
                                     usage=USAGE)

    parser.add_argument('--config',
                        required=True,
                        help="same configuration file used during training")

    parser.add_argument('--model-dir',
                        required=True,
                        help="checkpoint directory")

    parser.add_argument('--val-dir',
                        required=True,
                        help="validation data directory")

    parser.add_argument('--default-config',
                        help=str("same default configuration file used during "
                                 "training"))

    # a step of 0 would be an invalid slice step and a negative one would
    # walk the checkpoint list backwards, so both are rejected up front
    parser.add_argument('--checkpoint-step',
                        type=positive_int,
                        default=1,
                        help=str("set this to n to only consider every nth "
                                 "checkpoint"))

    # a limit below 1 used to be silently ignored (every batch was processed)
    parser.add_argument('--batch-limit',
                        type=positive_int,
                        help="batch limit")

    parser.add_argument('--verbose',
                        action='store_true',
                        help="display progress")

    return parser.parse_args()


def batch_losses(model, dataloader, loss, num_batches, verbose=False):
    """Return the loss of each of the first ``num_batches`` batches.

    Args:
        model: object exposing ``device`` and ``network._forward_encode``.
        dataloader: iterable of batches that also exposes ``pin_memory``.
        loss: callable applied to the pair returned by ``_forward_encode``.
        num_batches: maximum number of batches to evaluate.
        verbose: if true, report progress for every batch.

    Returns:
        A list of Python floats, one per evaluated batch. The list is empty
        if ``num_batches`` is smaller than 1 or the dataloader yields nothing.
    """
    losses = []

    if num_batches < 1:
        return losses

    # gradients are not needed for evaluation
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if verbose:
                display_progress(i, num_batches, msg="processing batch")

            batch = batch.to(model.device,
                             non_blocking=dataloader.pin_memory)

            q_pred, q_actual = model.network._forward_encode(batch)
            losses.append(loss(q_pred, q_actual).item())

            # stop right after the last requested batch so that no surplus
            # batch is fetched from the dataloader
            if i + 1 >= num_batches:
                break

    return losses


def main():
    """Print mean and standard deviation of the validation loss per checkpoint.

    Every ``checkpoint_*`` entry in the model directory (or every nth one if
    ``--checkpoint-step`` is given) is loaded in sorted order and evaluated on
    the validation data.
    """
    args = parse_args()

    # load configuration, redirecting the dataset root to the validation data
    cfg = config.get_config(args.config, args.default_config)

    config.modify_config(
        cfg, 'dataloader/params/dataset/params/root', args.val_dir)

    cfg = config.parse_config(cfg)

    # construct network
    model = cfg['model']

    model.network.eval()

    # dataset
    dataloader = cfg['dataloader']

    # loss
    loss = CrossEntropyLoss2d()

    # the number of batches to evaluate is the same for every checkpoint
    num_batches = len(dataloader)

    if args.batch_limit is not None:
        num_batches = min(num_batches, args.batch_limit)

    # determine average validation loss for each intermediate model checkpoint
    # NOTE(review): the sort is lexicographic; this presumably relies on
    # zero-padded checkpoint names - confirm against the training script
    checkpoints = sorted(glob(os.path.join(args.model_dir, 'checkpoint_*')))
    checkpoints = checkpoints[(args.checkpoint_step - 1)::args.checkpoint_step]

    if not checkpoints:
        print("no checkpoints selected in {}".format(args.model_dir),
              file=sys.stderr)

    for checkpoint in checkpoints:
        # load checkpoint
        print("{}:".format(checkpoint))
        model.load_checkpoint(checkpoint)

        # calculate batch losses
        losses = batch_losses(
            model, dataloader, loss, num_batches, verbose=args.verbose)

        # np.mean/np.std of an empty list would only yield nan and a warning
        if losses:
            print("=> {:.3} +/- {:.3}".format(np.mean(losses),
                                              np.std(losses)))
        else:
            print("=> no batches evaluated")


if __name__ == '__main__':
    main()
